1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Unit Test for MediaSampleReaderNDK
18
19 // #define LOG_NDEBUG 0
20 #define LOG_TAG "MediaSampleReaderNDKTests"
21
22 #include <android-base/logging.h>
23 #include <android/binder_manager.h>
24 #include <android/binder_process.h>
25 #include <fcntl.h>
26 #include <gtest/gtest.h>
27 #include <media/MediaSampleReaderNDK.h>
28 #include <openssl/md5.h>
29 #include <utils/Timers.h>
30
31 #include <cmath>
32 #include <mutex>
33 #include <thread>
34
35 // TODO(b/153453392): Test more asset types (frame reordering?).
36
37 namespace android {
38
// Converts a duration expressed in seconds into microseconds. The argument is parenthesized so
// that compound expressions such as SEC_TO_USEC(a + b) expand as expected.
#define SEC_TO_USEC(s) ((s) * 1000 * 1000)
40
41 /** Helper class for comparing sample data using checksums. */
42 class Sample {
43 public:
Sample(uint32_t flags,int64_t timestamp,size_t size,const uint8_t * buffer)44 Sample(uint32_t flags, int64_t timestamp, size_t size, const uint8_t* buffer)
45 : mFlags{flags}, mTimestamp{timestamp}, mSize{size} {
46 initChecksum(buffer);
47 }
48
Sample(AMediaExtractor * extractor)49 Sample(AMediaExtractor* extractor) {
50 mFlags = AMediaExtractor_getSampleFlags(extractor);
51 mTimestamp = AMediaExtractor_getSampleTime(extractor);
52 mSize = static_cast<size_t>(AMediaExtractor_getSampleSize(extractor));
53
54 auto buffer = std::make_unique<uint8_t[]>(mSize);
55 AMediaExtractor_readSampleData(extractor, buffer.get(), mSize);
56
57 initChecksum(buffer.get());
58 }
59
initChecksum(const uint8_t * buffer)60 void initChecksum(const uint8_t* buffer) {
61 MD5_CTX md5Ctx;
62 MD5_Init(&md5Ctx);
63 MD5_Update(&md5Ctx, buffer, mSize);
64 MD5_Final(mChecksum, &md5Ctx);
65 }
66
operator ==(const Sample & rhs) const67 bool operator==(const Sample& rhs) const {
68 return mSize == rhs.mSize && mFlags == rhs.mFlags && mTimestamp == rhs.mTimestamp &&
69 memcmp(mChecksum, rhs.mChecksum, MD5_DIGEST_LENGTH) == 0;
70 }
71
72 uint32_t mFlags;
73 int64_t mTimestamp;
74 size_t mSize;
75 uint8_t mChecksum[MD5_DIGEST_LENGTH];
76 };
77
/**
 * Sentinel sample count meaning "keep reading samples of the track until the reader reports that
 * no further sample is available".
 */
static constexpr int SAMPLE_COUNT_ALL{-1};
80
81 /**
82 * Utility class to test different sample access patterns combined with sequential or parallel
83 * sample access modes.
84 */
85 class SampleAccessTester {
86 public:
SampleAccessTester(int sourceFd,size_t fileSize)87 SampleAccessTester(int sourceFd, size_t fileSize) {
88 mSampleReader = MediaSampleReaderNDK::createFromFd(sourceFd, 0, fileSize);
89 EXPECT_TRUE(mSampleReader);
90
91 mTrackCount = mSampleReader->getTrackCount();
92
93 for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
94 EXPECT_EQ(mSampleReader->selectTrack(trackIndex), AMEDIA_OK);
95 }
96
97 mSamples.resize(mTrackCount);
98 mTrackThreads.resize(mTrackCount);
99 }
100
getSampleInfo(int trackIndex)101 void getSampleInfo(int trackIndex) {
102 MediaSampleInfo info;
103 media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
104 EXPECT_EQ(status, AMEDIA_OK);
105 }
106
readSamplesAsync(int trackIndex,int sampleCount)107 void readSamplesAsync(int trackIndex, int sampleCount) {
108 mTrackThreads[trackIndex] = std::thread{[this, trackIndex, sampleCount] {
109 int samplesRead = 0;
110 MediaSampleInfo info;
111 while (samplesRead < sampleCount || sampleCount == SAMPLE_COUNT_ALL) {
112 media_status_t status = mSampleReader->getSampleInfoForTrack(trackIndex, &info);
113 if (status != AMEDIA_OK) {
114 EXPECT_EQ(status, AMEDIA_ERROR_END_OF_STREAM);
115 EXPECT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) != 0);
116 break;
117 }
118 ASSERT_TRUE((info.flags & SAMPLE_FLAG_END_OF_STREAM) == 0);
119
120 auto buffer = std::make_unique<uint8_t[]>(info.size);
121 status = mSampleReader->readSampleDataForTrack(trackIndex, buffer.get(), info.size);
122 EXPECT_EQ(status, AMEDIA_OK);
123
124 mSampleMutex.lock();
125 const uint8_t* bufferPtr = buffer.get();
126 mSamples[trackIndex].emplace_back(info.flags, info.presentationTimeUs, info.size,
127 bufferPtr);
128 mSampleMutex.unlock();
129 ++samplesRead;
130 }
131 }};
132 }
133
readSamplesAsync(int sampleCount)134 void readSamplesAsync(int sampleCount) {
135 for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
136 readSamplesAsync(trackIndex, sampleCount);
137 }
138 }
139
waitForTrack(int trackIndex)140 void waitForTrack(int trackIndex) {
141 ASSERT_TRUE(mTrackThreads[trackIndex].joinable());
142 mTrackThreads[trackIndex].join();
143 }
144
waitForTracks()145 void waitForTracks() {
146 for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
147 waitForTrack(trackIndex);
148 }
149 }
150
setEnforceSequentialAccess(bool enforce)151 void setEnforceSequentialAccess(bool enforce) {
152 media_status_t status = mSampleReader->setEnforceSequentialAccess(enforce);
153 EXPECT_EQ(status, AMEDIA_OK);
154 }
155
getSamples()156 std::vector<std::vector<Sample>>& getSamples() { return mSamples; }
157
158 std::shared_ptr<MediaSampleReader> mSampleReader;
159 size_t mTrackCount;
160 std::mutex mSampleMutex;
161 std::vector<std::thread> mTrackThreads;
162 std::vector<std::vector<Sample>> mSamples;
163 };
164
165 class MediaSampleReaderNDKTests : public ::testing::Test {
166 public:
MediaSampleReaderNDKTests()167 MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests created"; }
168
SetUp()169 void SetUp() override {
170 LOG(DEBUG) << "MediaSampleReaderNDKTests set up";
171
172 // Need to start a thread pool to prevent AMediaExtractor binder calls from starving
173 // (b/155663561).
174 ABinderProcess_startThreadPool();
175
176 const char* sourcePath =
177 "/data/local/tmp/TranscodingTestAssets/cubicle_avc_480x240_aac_24KHz.mp4";
178
179 mSourceFd = open(sourcePath, O_RDONLY);
180 ASSERT_GT(mSourceFd, 0);
181
182 mFileSize = lseek(mSourceFd, 0, SEEK_END);
183 lseek(mSourceFd, 0, SEEK_SET);
184
185 mExtractor = AMediaExtractor_new();
186 ASSERT_NE(mExtractor, nullptr);
187
188 media_status_t status =
189 AMediaExtractor_setDataSourceFd(mExtractor, mSourceFd, 0, mFileSize);
190 ASSERT_EQ(status, AMEDIA_OK);
191
192 mTrackCount = AMediaExtractor_getTrackCount(mExtractor);
193 for (size_t trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
194 AMediaExtractor_selectTrack(mExtractor, trackIndex);
195 }
196 }
197
initExtractorSamples()198 void initExtractorSamples() {
199 if (mExtractorSamples.size() == mTrackCount) return;
200
201 // Save sample information, per track, as reported by the extractor.
202 mExtractorSamples.resize(mTrackCount);
203 do {
204 const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
205 mExtractorSamples[trackIndex].emplace_back(mExtractor);
206 } while (AMediaExtractor_advance(mExtractor));
207
208 AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
209 }
210
getTrackBitrates()211 std::vector<int32_t> getTrackBitrates() {
212 size_t totalSize[mTrackCount];
213 memset(totalSize, 0, sizeof(totalSize));
214
215 do {
216 const int trackIndex = AMediaExtractor_getSampleTrackIndex(mExtractor);
217 totalSize[trackIndex] += AMediaExtractor_getSampleSize(mExtractor);
218 } while (AMediaExtractor_advance(mExtractor));
219
220 AMediaExtractor_seekTo(mExtractor, 0, AMEDIAEXTRACTOR_SEEK_PREVIOUS_SYNC);
221
222 std::vector<int32_t> bitrates;
223 for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
224 int64_t durationUs;
225 AMediaFormat* trackFormat = AMediaExtractor_getTrackFormat(mExtractor, trackIndex);
226 EXPECT_NE(trackFormat, nullptr);
227 EXPECT_TRUE(AMediaFormat_getInt64(trackFormat, AMEDIAFORMAT_KEY_DURATION, &durationUs));
228 bitrates.push_back(roundf((float)totalSize[trackIndex] * 8 * 1000000 / durationUs));
229 }
230
231 return bitrates;
232 }
233
compareSamples(std::vector<std::vector<Sample>> & readerSamples)234 void compareSamples(std::vector<std::vector<Sample>>& readerSamples) {
235 initExtractorSamples();
236 EXPECT_EQ(readerSamples.size(), mTrackCount);
237
238 for (int trackIndex = 0; trackIndex < mTrackCount; trackIndex++) {
239 LOG(DEBUG) << "Track " << trackIndex << ", comparing "
240 << readerSamples[trackIndex].size() << " samples.";
241 EXPECT_EQ(readerSamples[trackIndex].size(), mExtractorSamples[trackIndex].size());
242 for (size_t sampleIndex = 0; sampleIndex < readerSamples[trackIndex].size();
243 sampleIndex++) {
244 EXPECT_EQ(readerSamples[trackIndex][sampleIndex],
245 mExtractorSamples[trackIndex][sampleIndex]);
246 }
247 }
248 }
249
TearDown()250 void TearDown() override {
251 LOG(DEBUG) << "MediaSampleReaderNDKTests tear down";
252 AMediaExtractor_delete(mExtractor);
253 close(mSourceFd);
254 }
255
~MediaSampleReaderNDKTests()256 ~MediaSampleReaderNDKTests() { LOG(DEBUG) << "MediaSampleReaderNDKTests destroyed"; }
257
258 AMediaExtractor* mExtractor = nullptr;
259 size_t mTrackCount;
260 int mSourceFd;
261 size_t mFileSize;
262 std::vector<std::vector<Sample>> mExtractorSamples;
263 };
264
265 /** Reads all samples from all tracks in parallel. */
TEST_F(MediaSampleReaderNDKTests,TestParallelSampleAccess)266 TEST_F(MediaSampleReaderNDKTests, TestParallelSampleAccess) {
267 LOG(DEBUG) << "TestParallelSampleAccess Starts";
268
269 SampleAccessTester tester{mSourceFd, mFileSize};
270 tester.readSamplesAsync(SAMPLE_COUNT_ALL);
271 tester.waitForTracks();
272 compareSamples(tester.getSamples());
273 }
274
275 /** Reads all samples except the last in each track, before finishing. */
TEST_F(MediaSampleReaderNDKTests,TestLastSampleBeforeEOS)276 TEST_F(MediaSampleReaderNDKTests, TestLastSampleBeforeEOS) {
277 LOG(DEBUG) << "TestLastSampleBeforeEOS Starts";
278 initExtractorSamples();
279
280 { // Natural track order
281 SampleAccessTester tester{mSourceFd, mFileSize};
282 for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
283 tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() - 1);
284 }
285 tester.waitForTracks();
286 for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
287 tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
288 tester.waitForTrack(trackIndex);
289 }
290 compareSamples(tester.getSamples());
291 }
292
293 { // Reverse track order
294 SampleAccessTester tester{mSourceFd, mFileSize};
295 for (int trackIndex = mTrackCount - 1; trackIndex >= 0; --trackIndex) {
296 tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() - 1);
297 }
298 tester.waitForTracks();
299 for (int trackIndex = mTrackCount - 1; trackIndex >= 0; --trackIndex) {
300 tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
301 tester.waitForTrack(trackIndex);
302 }
303 compareSamples(tester.getSamples());
304 }
305 }
306
307 /** Reads all samples from all tracks sequentially. */
TEST_F(MediaSampleReaderNDKTests,TestSequentialSampleAccess)308 TEST_F(MediaSampleReaderNDKTests, TestSequentialSampleAccess) {
309 LOG(DEBUG) << "TestSequentialSampleAccess Starts";
310
311 SampleAccessTester tester{mSourceFd, mFileSize};
312 tester.setEnforceSequentialAccess(true);
313 tester.readSamplesAsync(SAMPLE_COUNT_ALL);
314 tester.waitForTracks();
315 compareSamples(tester.getSamples());
316 }
317
318 /** Reads all samples from one track in parallel mode before switching to sequential mode. */
TEST_F(MediaSampleReaderNDKTests,TestMixedSampleAccessTrackEOS)319 TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccessTrackEOS) {
320 LOG(DEBUG) << "TestMixedSampleAccessTrackEOS Starts";
321
322 for (int readSampleInfoFlag = 0; readSampleInfoFlag <= 1; readSampleInfoFlag++) {
323 for (int trackIndToEOS = 0; trackIndToEOS < mTrackCount; ++trackIndToEOS) {
324 LOG(DEBUG) << "Testing EOS of track " << trackIndToEOS;
325
326 SampleAccessTester tester{mSourceFd, mFileSize};
327
328 // If the flag is set, read sample info from a different track before draining the track
329 // under test to force the reader to save the extractor position.
330 if (readSampleInfoFlag) {
331 tester.getSampleInfo((trackIndToEOS + 1) % mTrackCount);
332 }
333
334 // Read all samples from one track before enabling sequential access
335 tester.readSamplesAsync(trackIndToEOS, SAMPLE_COUNT_ALL);
336 tester.waitForTrack(trackIndToEOS);
337 tester.setEnforceSequentialAccess(true);
338
339 for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
340 if (trackIndex == trackIndToEOS) continue;
341
342 tester.readSamplesAsync(trackIndex, SAMPLE_COUNT_ALL);
343 tester.waitForTrack(trackIndex);
344 }
345
346 compareSamples(tester.getSamples());
347 }
348 }
349 }
350
351 /**
352 * Reads different combinations of sample counts from all tracks in parallel mode before switching
353 * to sequential mode and reading the rest of the samples.
354 */
TEST_F(MediaSampleReaderNDKTests,TestMixedSampleAccess)355 TEST_F(MediaSampleReaderNDKTests, TestMixedSampleAccess) {
356 LOG(DEBUG) << "TestMixedSampleAccess Starts";
357 initExtractorSamples();
358
359 for (int trackIndToTest = 0; trackIndToTest < mTrackCount; ++trackIndToTest) {
360 for (int sampleCount = 0; sampleCount <= (mExtractorSamples[trackIndToTest].size() + 1);
361 ++sampleCount) {
362 SampleAccessTester tester{mSourceFd, mFileSize};
363
364 for (int trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
365 if (trackIndex == trackIndToTest) {
366 tester.readSamplesAsync(trackIndex, sampleCount);
367 } else {
368 tester.readSamplesAsync(trackIndex, mExtractorSamples[trackIndex].size() / 2);
369 }
370 }
371
372 tester.waitForTracks();
373 tester.setEnforceSequentialAccess(true);
374
375 tester.readSamplesAsync(SAMPLE_COUNT_ALL);
376 tester.waitForTracks();
377
378 compareSamples(tester.getSamples());
379 }
380 }
381 }
382
/**
 * Checks each track's estimated bitrate. It must be positive, below a sanity bound, and equal to
 * the bitrate computed from the extractor's sample sizes and the track duration.
 */
TEST_F(MediaSampleReaderNDKTests, TestEstimatedBitrateAccuracy) {
    // Just put a somewhat reasonable upper bound on the estimated bitrate expected in our test
    // assets. This is mostly to make sure the estimation is not way off.
    static constexpr int32_t kMaxEstimatedBitrate = 100 * 1000 * 1000;  // 100 Mbps

    auto sampleReader = MediaSampleReaderNDK::createFromFd(mSourceFd, 0, mFileSize);
    ASSERT_TRUE(sampleReader);

    const std::vector<int32_t> actualTrackBitrates = getTrackBitrates();
    ASSERT_EQ(actualTrackBitrates.size(), mTrackCount);

    for (size_t trackIndex = 0; trackIndex < mTrackCount; ++trackIndex) {
        EXPECT_EQ(sampleReader->selectTrack(trackIndex), AMEDIA_OK);

        // Initialized so that a failed query is reported as a value mismatch instead of reading
        // an indeterminate value.
        int32_t bitrate = 0;
        EXPECT_EQ(sampleReader->getEstimatedBitrateForTrack(trackIndex, &bitrate), AMEDIA_OK);
        EXPECT_GT(bitrate, 0);
        EXPECT_LT(bitrate, kMaxEstimatedBitrate);

        // Note: The test asset currently used in this test is shorter than the sampling duration
        // used to estimate the bitrate in the sample reader. So for now the estimation should be
        // exact but if/when a longer asset is used a reasonable delta needs to be defined.
        EXPECT_EQ(bitrate, actualTrackBitrates[trackIndex]);
    }
}
406
/** Creating a reader from descriptor 0 or from a negative descriptor must fail. */
TEST_F(MediaSampleReaderNDKTests, TestInvalidFd) {
    std::shared_ptr<MediaSampleReader> reader;

    reader = MediaSampleReaderNDK::createFromFd(0, 0, mFileSize);
    ASSERT_FALSE(reader);

    reader = MediaSampleReaderNDK::createFromFd(-1, 0, mFileSize);
    ASSERT_FALSE(reader);
}
415
/** Creating a reader over a zero-length range of a valid file must fail. */
TEST_F(MediaSampleReaderNDKTests, TestZeroSize) {
    const std::shared_ptr<MediaSampleReader> emptyRangeReader =
            MediaSampleReaderNDK::createFromFd(mSourceFd, 0, 0);
    ASSERT_FALSE(emptyRangeReader);
}
421
/** Creating a reader whose range starts at the end of the file must fail. */
TEST_F(MediaSampleReaderNDKTests, TestInvalidOffset) {
    const size_t offsetAtEndOfFile = mFileSize;
    const std::shared_ptr<MediaSampleReader> pastEndReader =
            MediaSampleReaderNDK::createFromFd(mSourceFd, offsetAtEndOfFile, mFileSize);
    ASSERT_FALSE(pastEndReader);
}
427
428 } // namespace android
429
main(int argc,char ** argv)430 int main(int argc, char** argv) {
431 ::testing::InitGoogleTest(&argc, argv);
432 return RUN_ALL_TESTS();
433 }
434