• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "drm_hal_common"
18 
19 #include <gtest/gtest.h>
20 #include <log/log.h>
21 #include <openssl/aes.h>
22 #include <sys/mman.h>
23 #include <random>
24 
25 #include <aidlcommonsupport/NativeHandle.h>
26 #include <android/binder_manager.h>
27 #include <android/binder_process.h>
28 #include <android/sharedmem.h>
29 #include <cutils/native_handle.h>
30 
31 #include "drm_hal_clearkey_module.h"
32 #include "drm_hal_common.h"
33 
34 namespace aidl {
35 namespace android {
36 namespace hardware {
37 namespace drm {
38 namespace vts {
39 
40 namespace clearkeydrm = ::android::hardware::drm::V1_2::vts;
41 
42 using std::vector;
43 using ::aidl::android::hardware::common::Ashmem;
44 using ::aidl::android::hardware::drm::DecryptArgs;
45 using ::aidl::android::hardware::drm::DestinationBuffer;
46 using ::aidl::android::hardware::drm::EventType;
47 using ::aidl::android::hardware::drm::ICryptoPlugin;
48 using ::aidl::android::hardware::drm::IDrmPlugin;
49 using ::aidl::android::hardware::drm::KeyRequest;
50 using ::aidl::android::hardware::drm::KeyRequestType;
51 using ::aidl::android::hardware::drm::KeySetId;
52 using ::aidl::android::hardware::drm::KeyType;
53 using ::aidl::android::hardware::drm::KeyValue;
54 using ::aidl::android::hardware::drm::Mode;
55 using ::aidl::android::hardware::drm::Pattern;
56 using ::aidl::android::hardware::drm::ProvisionRequest;
57 using ::aidl::android::hardware::drm::ProvideProvisionResponseResult;
58 using ::aidl::android::hardware::drm::SecurityLevel;
59 using ::aidl::android::hardware::drm::Status;
60 using ::aidl::android::hardware::drm::SubSample;
61 using ::aidl::android::hardware::drm::Uuid;
62 
DrmErr(const::ndk::ScopedAStatus & ret)63 Status DrmErr(const ::ndk::ScopedAStatus& ret) {
64     return static_cast<Status>(ret.getServiceSpecificError());
65 }
66 
/**
 * Returns everything after the first '/' of a full HAL instance name
 * ("iface/instance" -> "instance"), or the whole string when it contains no '/'.
 */
std::string HalBaseName(const std::string& fullname) {
    const auto slash = fullname.find('/');
    return slash == std::string::npos ? fullname : fullname.substr(slash + 1);
}
74 
// AIDL descriptor of the DRM factory interface; full instance names are built as
// "<kDrmIface>/<instance>" via HalFullName().
const char* kDrmIface = "android.hardware.drm.IDrmFactory";
// Maximum number of openSession() attempts; an attempt that fails with
// ERROR_DRM_NOT_PROVISIONED triggers provision() before the next one.
const int MAX_OPEN_SESSION_ATTEMPTS = 3;
77 
/**
 * Joins an interface descriptor and an instance name into "iface/basename".
 */
std::string HalFullName(const std::string& iface, const std::string& basename) {
    std::string fullname;
    fullname.reserve(iface.size() + 1 + basename.size());
    fullname.append(iface).append(1, '/').append(basename);
    return fullname;
}
81 
IsOk(const::ndk::ScopedAStatus & ret)82 testing::AssertionResult IsOk(const ::ndk::ScopedAStatus& ret) {
83     if (ret.isOk()) {
84         return testing::AssertionSuccess();
85     }
86     return testing::AssertionFailure() << "ex: " << ret.getExceptionCode()
87                                        << "; svc err: " << ret.getServiceSpecificError()
88                                        << "; desc: " << ret.getDescription();
89 }
90 
// Callback name strings. They are not referenced elsewhere in this file; presumably
// declared in drm_hal_common.h and used by the tests — TODO confirm.
const char* kCallbackLostState = "LostState";
const char* kCallbackKeysChange = "KeysChange";

// Registry of vendor-supplied test modules. It starts out null, and getModuleForInstance()
// consults it for every instance that is not ClearKey. It is presumably installed by the
// test main before DrmHalTest objects are constructed — verify.
drm_vts::VendorModules* DrmHalTest::gVendorModules = nullptr;
95 
96 /**
97  * DrmHalPluginListener
98  */
onEvent(EventType eventType,const vector<uint8_t> & sessionId,const vector<uint8_t> & data)99 ::ndk::ScopedAStatus DrmHalPluginListener::onEvent(
100         EventType eventType,
101         const vector<uint8_t>& sessionId,
102         const vector<uint8_t>& data) {
103     ListenerArgs args{};
104     args.eventType = eventType;
105     args.sessionId = sessionId;
106     args.data = data;
107     eventPromise.set_value(args);
108     return ::ndk::ScopedAStatus::ok();
109 }
110 
onExpirationUpdate(const vector<uint8_t> & sessionId,int64_t expiryTimeInMS)111 ::ndk::ScopedAStatus DrmHalPluginListener::onExpirationUpdate(
112         const vector<uint8_t>& sessionId,
113         int64_t expiryTimeInMS) {
114     ListenerArgs args{};
115     args.sessionId = sessionId;
116     args.expiryTimeInMS = expiryTimeInMS;
117     expirationUpdatePromise.set_value(args);
118     return ::ndk::ScopedAStatus::ok();
119 
120 }
121 
onSessionLostState(const vector<uint8_t> & sessionId)122 ::ndk::ScopedAStatus DrmHalPluginListener::onSessionLostState(const vector<uint8_t>& sessionId) {
123     ListenerArgs args{};
124     args.sessionId = sessionId;
125     sessionLostStatePromise.set_value(args);
126     return ::ndk::ScopedAStatus::ok();
127 }
128 
onKeysChange(const std::vector<uint8_t> & sessionId,const std::vector<::aidl::android::hardware::drm::KeyStatus> & keyStatusList,bool hasNewUsableKey)129 ::ndk::ScopedAStatus DrmHalPluginListener::onKeysChange(
130         const std::vector<uint8_t>& sessionId,
131         const std::vector<::aidl::android::hardware::drm::KeyStatus>& keyStatusList,
132         bool hasNewUsableKey) {
133     ListenerArgs args{};
134     args.sessionId = sessionId;
135     args.keyStatusList = keyStatusList;
136     args.hasNewUsableKey = hasNewUsableKey;
137     keysChangePromise.set_value(args);
138     return ::ndk::ScopedAStatus::ok();
139 }
140 
getListenerArgs(std::promise<ListenerArgs> & promise)141 ListenerArgs DrmHalPluginListener::getListenerArgs(std::promise<ListenerArgs>& promise) {
142     auto future = promise.get_future();
143     auto timeout = std::chrono::milliseconds(500);
144     EXPECT_EQ(future.wait_for(timeout), std::future_status::ready);
145     return future.get();
146 }
147 
// The accessors below wait, via getListenerArgs() and its timeout, for the matching
// callback and return the arguments that callback recorded.
ListenerArgs DrmHalPluginListener::getEventArgs() {
    return getListenerArgs(eventPromise);
}

// Arguments of the onExpirationUpdate() callback.
ListenerArgs DrmHalPluginListener::getExpirationUpdateArgs() {
    return getListenerArgs(expirationUpdatePromise);
}

// Arguments of the onSessionLostState() callback.
ListenerArgs DrmHalPluginListener::getSessionLostStateArgs() {
    return getListenerArgs(sessionLostStatePromise);
}

// Arguments of the onKeysChange() callback.
ListenerArgs DrmHalPluginListener::getKeysChangeArgs() {
    return getListenerArgs(keysChangePromise);
}
163 
getModuleForInstance(const std::string & instance)164 static DrmHalVTSVendorModule_V1* getModuleForInstance(const std::string& instance) {
165     if (instance.find("clearkey") != std::string::npos ||
166         instance.find("default") != std::string::npos) {
167         return new clearkeydrm::DrmHalVTSClearkeyModule();
168     }
169 
170     return static_cast<DrmHalVTSVendorModule_V1*>(
171             DrmHalTest::gVendorModules->getModuleByName(instance));
172 }
173 
/**
 * DrmHalTest
 */

// Selects the vendor module for the parameterized HAL instance. The module may be null
// when no module matches; SetUp() checks for that before using it.
DrmHalTest::DrmHalTest() : vendorModule(getModuleForInstance(GetParamService())) {}
179 
/**
 * Per-test setup:
 * - connects to the IDrmFactory instance for the test parameter;
 * - creates the DRM and crypto plugins;
 * - loads the vendor module's content configurations;
 * - skips or fails the test when the scheme is not supported.
 */
void DrmHalTest::SetUp() {
    const ::testing::TestInfo* const test_info =
            ::testing::UnitTest::GetInstance()->current_test_info();

    ALOGD("Running test %s.%s from (vendor) module %s", test_info->test_case_name(),
          test_info->name(), GetParamService().c_str());

    auto svc = GetParamService();
    const string drmInstance = HalFullName(kDrmIface, svc);

    // Without a vendor module, Widevine/ClearKey instances are hard failures; anything
    // else is skipped. NOTE(review): getModuleForInstance() always builds a ClearKey module
    // for names containing "clearkey", so the ClearKey assertion appears unreachable.
    if (!vendorModule) {
        ASSERT_NE(drmInstance, HalFullName(kDrmIface, "widevine")) << "Widevine requires vendor module.";
        ASSERT_NE(drmInstance, HalFullName(kDrmIface, "clearkey")) << "Clearkey requires vendor module.";
        GTEST_SKIP() << "No vendor module installed";
    }

    // drmInstance is built from kDrmIface ("android.hardware.drm.IDrmFactory"), so this
    // condition holds for every instance constructed above.
    if (drmInstance.find("IDrmFactory") != std::string::npos) {
        drmFactory = IDrmFactory::fromBinder(
                ::ndk::SpAIBinder(AServiceManager_waitForService(drmInstance.c_str())));
        ASSERT_NE(drmFactory, nullptr);
        drmPlugin = createDrmPlugin();
        cryptoPlugin = createCryptoPlugin();
    }

    ASSERT_EQ(HalBaseName(drmInstance), vendorModule->getServiceName());
    contentConfigurations = vendorModule->getContentConfigurations();

    // If drm scheme not installed skip subsequent tests.
    // An explicitly requested scheme (non-zero param UUID) must be supported, though.
    bool result = isCryptoSchemeSupported(getAidlUUID(), SecurityLevel::SW_SECURE_CRYPTO, "cenc");
    if (!result) {
        if (GetParamUUID() == std::array<uint8_t, 16>()) {
            GTEST_SKIP() << "vendor module drm scheme not supported";
        } else {
            FAIL() << "param scheme must be supported";
        }
    }

    ASSERT_NE(nullptr, drmPlugin.get())
            << "Can't find " << vendorModule->getServiceName() << " drm aidl plugin";
    ASSERT_NE(nullptr, cryptoPlugin.get())
            << "Can't find " << vendorModule->getServiceName() << " crypto aidl plugin";
}
222 
createDrmPlugin()223 std::shared_ptr<::aidl::android::hardware::drm::IDrmPlugin> DrmHalTest::createDrmPlugin() {
224     if (drmFactory == nullptr) {
225         return nullptr;
226     }
227     std::string packageName("aidl.android.hardware.drm.test");
228     std::shared_ptr<::aidl::android::hardware::drm::IDrmPlugin> result;
229     auto ret = drmFactory->createDrmPlugin(getAidlUUID(), packageName, &result);
230     EXPECT_OK(ret) << "createDrmPlugin remote call failed";
231     return result;
232 }
233 
createCryptoPlugin()234 std::shared_ptr<::aidl::android::hardware::drm::ICryptoPlugin> DrmHalTest::createCryptoPlugin() {
235     if (drmFactory == nullptr) {
236         return nullptr;
237     }
238     vector<uint8_t> initVec;
239     std::shared_ptr<::aidl::android::hardware::drm::ICryptoPlugin> result;
240     auto ret = drmFactory->createCryptoPlugin(getAidlUUID(), initVec, &result);
241     EXPECT_OK(ret) << "createCryptoPlugin remote call failed";
242     return result;
243 }
244 
// The scheme UUID under test (see getUUID()) converted to the AIDL Uuid type.
::aidl::android::hardware::drm::Uuid DrmHalTest::getAidlUUID() {
    return toAidlUuid(getUUID());
}
248 
getUUID()249 std::vector<uint8_t> DrmHalTest::getUUID() {
250     auto paramUUID = GetParamUUID();
251     if (paramUUID == std::array<uint8_t, 16>()) {
252         return getVendorUUID();
253     }
254     return std::vector(paramUUID.begin(), paramUUID.end());
255 }
256 
getVendorUUID()257 std::vector<uint8_t> DrmHalTest::getVendorUUID() {
258     if (vendorModule == nullptr) {
259         ALOGW("vendor module for %s not found", GetParamService().c_str());
260         return std::vector<uint8_t>(16);
261     }
262     return vendorModule->getUUID();
263 }
264 
isCryptoSchemeSupported(Uuid uuid,SecurityLevel level,std::string mime)265 bool DrmHalTest::isCryptoSchemeSupported(Uuid uuid, SecurityLevel level, std::string mime) {
266     if (drmFactory == nullptr) {
267         return false;
268     }
269     CryptoSchemes schemes{};
270     auto ret = drmFactory->getSupportedCryptoSchemes(&schemes);
271     EXPECT_OK(ret);
272     if (!ret.isOk() || !std::count(schemes.uuids.begin(), schemes.uuids.end(), uuid)) {
273         return false;
274     }
275     if (mime.empty()) {
276         EXPECT_THAT(level, AnyOf(Eq(SecurityLevel::DEFAULT), Eq(SecurityLevel::UNKNOWN)));
277         return true;
278     }
279     for (auto ct : schemes.mimeTypes) {
280         if (ct.mime != mime) {
281             continue;
282         }
283         if (level == SecurityLevel::DEFAULT || level == SecurityLevel::UNKNOWN) {
284             return true;
285         }
286         if (level <= ct.maxLevel && level >= ct.minLevel) {
287             return true;
288         }
289     }
290     return false;
291 }
292 
provision()293 void DrmHalTest::provision() {
294     std::string certificateType;
295     std::string certificateAuthority;
296     vector<uint8_t> provisionRequest;
297     std::string defaultUrl;
298     ProvisionRequest result;
299     auto ret = drmPlugin->getProvisionRequest(certificateType, certificateAuthority, &result);
300 
301     EXPECT_TXN(ret);
302     if (ret.isOk()) {
303         EXPECT_NE(result.request.size(), 0u);
304         provisionRequest = result.request;
305         defaultUrl = result.defaultUrl;
306     } else if (DrmErr(ret) == Status::ERROR_DRM_CANNOT_HANDLE) {
307         EXPECT_EQ(0u, result.request.size());
308     }
309 
310     if (provisionRequest.size() > 0) {
311         vector<uint8_t> response =
312                 vendorModule->handleProvisioningRequest(provisionRequest, defaultUrl);
313         ASSERT_NE(0u, response.size());
314 
315         ProvideProvisionResponseResult result;
316         auto ret = drmPlugin->provideProvisionResponse(response, &result);
317         EXPECT_TXN(ret);
318     }
319 }
320 
openSession(SecurityLevel level,Status * err)321 SessionId DrmHalTest::openSession(SecurityLevel level, Status* err) {
322     SessionId sessionId;
323     auto ret = drmPlugin->openSession(level, &sessionId);
324     EXPECT_TXN(ret);
325     *err = DrmErr(ret);
326     return sessionId;
327 }
328 
329 /**
330  * Helper method to open a session and verify that a non-empty
331  * session ID is returned
332  */
openSession()333 SessionId DrmHalTest::openSession() {
334     SessionId sessionId;
335 
336     int attmpt = 0;
337     while (attmpt++ < MAX_OPEN_SESSION_ATTEMPTS) {
338         auto ret = drmPlugin->openSession(SecurityLevel::DEFAULT, &sessionId);
339         if(DrmErr(ret) == Status::ERROR_DRM_NOT_PROVISIONED) {
340             provision();
341         } else {
342             EXPECT_OK(ret);
343             EXPECT_NE(0u, sessionId.size());
344             break;
345         }
346     }
347 
348     return sessionId;
349 }
350 
/**
 * Helper method to close a session; a failed binder call is recorded as a non-fatal
 * test failure.
 */
void DrmHalTest::closeSession(const SessionId& sessionId) {
    auto ret = drmPlugin->closeSession(sessionId);
    EXPECT_OK(ret);
}
358 
getKeyRequest(const SessionId & sessionId,const DrmHalVTSVendorModule_V1::ContentConfiguration & configuration,const KeyType & type=KeyType::STREAMING)359 vector<uint8_t> DrmHalTest::getKeyRequest(
360         const SessionId& sessionId,
361         const DrmHalVTSVendorModule_V1::ContentConfiguration& configuration,
362         const KeyType& type = KeyType::STREAMING) {
363     KeyRequest result;
364     auto ret = drmPlugin->getKeyRequest(sessionId, configuration.initData, configuration.mimeType,
365                                         type, toAidlKeyedVector(configuration.optionalParameters),
366                                         &result);
367     EXPECT_OK(ret) << "Failed to get key request for configuration "
368                    << configuration.name << " for key type "
369                    << static_cast<int>(type);
370     if (type == KeyType::RELEASE) {
371         EXPECT_EQ(KeyRequestType::RELEASE, result.requestType);
372     } else {
373         EXPECT_EQ(KeyRequestType::INITIAL, result.requestType);
374     }
375     EXPECT_NE(result.request.size(), 0u) << "Expected key request size"
376                                             " to have length > 0 bytes";
377     return result.request;
378 }
379 
getContent(const KeyType & type) const380 DrmHalVTSVendorModule_V1::ContentConfiguration DrmHalTest::getContent(const KeyType& type) const {
381     for (const auto& config : contentConfigurations) {
382         if (type != KeyType::OFFLINE || config.policy.allowOffline) {
383             return config;
384         }
385     }
386     ADD_FAILURE() << "no content configurations found";
387     return {};
388 }
389 
/**
 * Hands a key response to the DRM plugin for the given session.
 *
 * @return the keySetId field of the plugin's result
 */
vector<uint8_t> DrmHalTest::provideKeyResponse(const SessionId& sessionId,
                                               const vector<uint8_t>& keyResponse) {
    KeySetId result;
    auto ret = drmPlugin->provideKeyResponse(sessionId, keyResponse, &result);
    EXPECT_OK(ret) << "Failure providing key response";
    return result.keySetId;
}
397 
398 /**
399  * Helper method to load keys for subsequent decrypt tests.
400  * These tests use predetermined key request/response to
401  * avoid requiring a round trip to a license server.
402  */
loadKeys(const SessionId & sessionId,const DrmHalVTSVendorModule_V1::ContentConfiguration & configuration,const KeyType & type)403 vector<uint8_t> DrmHalTest::loadKeys(
404         const SessionId& sessionId,
405         const DrmHalVTSVendorModule_V1::ContentConfiguration& configuration, const KeyType& type) {
406     vector<uint8_t> keyRequest = getKeyRequest(sessionId, configuration, type);
407 
408     /**
409      * Get key response from vendor module
410      */
411     vector<uint8_t> keyResponse =
412             vendorModule->handleKeyRequest(keyRequest, configuration.serverUrl);
413     EXPECT_NE(keyResponse.size(), 0u) << "Expected key response size "
414                                          "to have length > 0 bytes";
415 
416     return provideKeyResponse(sessionId, keyResponse);
417 }
418 
loadKeys(const SessionId & sessionId,const KeyType & type)419 vector<uint8_t> DrmHalTest::loadKeys(const SessionId& sessionId, const KeyType& type) {
420     return loadKeys(sessionId, getContent(type), type);
421 }
422 
toStdArray(const vector<uint8_t> & vec)423 std::array<uint8_t, 16> DrmHalTest::toStdArray(const vector<uint8_t>& vec) {
424     EXPECT_EQ(16u, vec.size());
425     std::array<uint8_t, 16> arr;
426     std::copy_n(vec.begin(), vec.size(), arr.begin());
427     return arr;
428 }
429 
toAidlKeyedVector(const map<string,string> & params)430 KeyedVector DrmHalTest::toAidlKeyedVector(const map<string, string>& params) {
431     std::vector<KeyValue> stdKeyedVector;
432     for (auto it = params.begin(); it != params.end(); ++it) {
433         KeyValue keyValue;
434         keyValue.key = it->first;
435         keyValue.value = it->second;
436         stdKeyedVector.push_back(keyValue);
437     }
438     return KeyedVector(stdKeyedVector);
439 }
440 
441 /**
442  * getDecryptMemory allocates memory for decryption, then sets it
443  * as a shared buffer base in the crypto hal. An output SharedBuffer
444  * is updated via reference.
445  *
446  * @param size the size of the memory segment to allocate
447  * @param the index of the memory segment which will be used
448  * to refer to it for decryption.
449  */
getDecryptMemory(size_t size,size_t index,SharedBuffer & out)450 void DrmHalTest::getDecryptMemory(size_t size, size_t index, SharedBuffer& out) {
451     out.bufferId = static_cast<int32_t>(index);
452     out.offset = 0;
453     out.size = static_cast<int64_t>(size);
454 
455     int fd = ASharedMemory_create("drmVtsSharedMemory", size);
456     EXPECT_GE(fd, 0);
457     EXPECT_EQ(size, ASharedMemory_getSize(fd));
458     auto handle = native_handle_create(1, 0);
459     handle->data[0] = fd;
460     out.handle = ::android::makeToAidl(handle);
461 
462     EXPECT_OK(cryptoPlugin->setSharedBufferBase(out));
463     native_handle_delete(handle);
464 }
465 
fillRandom(const::aidl::android::hardware::drm::SharedBuffer & buf)466 uint8_t* DrmHalTest::fillRandom(const ::aidl::android::hardware::drm::SharedBuffer& buf) {
467     std::random_device rd;
468     std::mt19937 rand(rd());
469 
470     auto fd = buf.handle.fds[0].get();
471     size_t size = buf.size;
472     uint8_t* base = static_cast<uint8_t*>(
473             mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0));
474     EXPECT_NE(MAP_FAILED, base);
475     for (size_t i = 0; i < size / sizeof(uint32_t); i++) {
476         auto p = static_cast<uint32_t*>(static_cast<void*>(base));
477         p[i] = rand();
478     }
479     return base;
480 }
481 
decrypt(Mode mode,bool isSecure,const std::array<uint8_t,16> & keyId,uint8_t * iv,const vector<SubSample> & subSamples,const Pattern & pattern,const vector<uint8_t> & key,Status expectedStatus)482 uint32_t DrmHalTest::decrypt(Mode mode, bool isSecure, const std::array<uint8_t, 16>& keyId,
483                              uint8_t* iv, const vector<SubSample>& subSamples,
484                              const Pattern& pattern, const vector<uint8_t>& key,
485                              Status expectedStatus) {
486     const size_t kSegmentIndex = 0;
487 
488     uint8_t localIv[AES_BLOCK_SIZE];
489     memcpy(localIv, iv, AES_BLOCK_SIZE);
490     vector<uint8_t> ivVec(localIv, localIv + AES_BLOCK_SIZE);
491     vector<uint8_t> keyIdVec(keyId.begin(), keyId.end());
492 
493     int64_t totalSize = 0;
494     for (size_t i = 0; i < subSamples.size(); i++) {
495         totalSize += subSamples[i].numBytesOfClearData;
496         totalSize += subSamples[i].numBytesOfEncryptedData;
497     }
498 
499     // The first totalSize bytes of shared memory is the encrypted
500     // input, the second totalSize bytes (if exists) is the decrypted output.
501     size_t factor = expectedStatus == Status::ERROR_DRM_FRAME_TOO_LARGE ? 1 : 2;
502     SharedBuffer sourceBuffer;
503     getDecryptMemory(totalSize * factor, kSegmentIndex, sourceBuffer);
504     auto base = fillRandom(sourceBuffer);
505 
506     SharedBuffer sourceRange;
507     sourceRange.bufferId = kSegmentIndex;
508     sourceRange.offset = 0;
509     sourceRange.size = totalSize;
510 
511     SharedBuffer destRange;
512     destRange.bufferId = kSegmentIndex;
513     destRange.offset = totalSize;
514     destRange.size = totalSize;
515 
516     DecryptArgs args;
517     args.secure = isSecure;
518     args.keyId = keyIdVec;
519     args.iv = ivVec;
520     args.mode = mode;
521     args.pattern = pattern;
522     args.subSamples = subSamples;
523     args.source = std::move(sourceRange);
524     args.offset = 0;
525     args.destination = std::move(destRange);
526 
527     int32_t bytesWritten = 0;
528     auto ret = cryptoPlugin->decrypt(args, &bytesWritten);
529     EXPECT_TXN(ret);
530     EXPECT_EQ(expectedStatus, DrmErr(ret)) << "Unexpected decrypt status " << ret.getMessage();
531 
532     if (bytesWritten != totalSize) {
533         return bytesWritten;
534     }
535 
536     // generate reference vector
537     vector<uint8_t> reference(totalSize);
538 
539     memcpy(localIv, iv, AES_BLOCK_SIZE);
540     switch (mode) {
541         case Mode::UNENCRYPTED:
542             memcpy(&reference[0], base, totalSize);
543             break;
544         case Mode::AES_CTR:
545             aes_ctr_decrypt(&reference[0], base, localIv, subSamples, key);
546             break;
547         case Mode::AES_CBC:
548             aes_cbc_decrypt(&reference[0], base, localIv, subSamples, key);
549             break;
550         case Mode::AES_CBC_CTS:
551             ADD_FAILURE() << "AES_CBC_CTS mode not supported";
552             break;
553     }
554 
555     // compare reference to decrypted data which is at base + total size
556     EXPECT_EQ(0, memcmp(static_cast<void*>(&reference[0]), static_cast<void*>(base + totalSize),
557                         totalSize))
558             << "decrypt data mismatch";
559     munmap(base, totalSize * factor);
560     return totalSize;
561 }
562 
563 /**
564  * Decrypt a list of clear+encrypted subsamples using the specified key
565  * in AES-CTR mode
566  */
aes_ctr_decrypt(uint8_t * dest,uint8_t * src,uint8_t * iv,const vector<SubSample> & subSamples,const vector<uint8_t> & key)567 void DrmHalTest::aes_ctr_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
568                                  const vector<SubSample>& subSamples, const vector<uint8_t>& key) {
569     AES_KEY decryptionKey;
570     AES_set_encrypt_key(&key[0], 128, &decryptionKey);
571 
572     size_t offset = 0;
573     unsigned int blockOffset = 0;
574     uint8_t previousEncryptedCounter[AES_BLOCK_SIZE];
575     memset(previousEncryptedCounter, 0, AES_BLOCK_SIZE);
576 
577     for (size_t i = 0; i < subSamples.size(); i++) {
578         const SubSample& subSample = subSamples[i];
579 
580         if (subSample.numBytesOfClearData > 0) {
581             memcpy(dest + offset, src + offset, subSample.numBytesOfClearData);
582             offset += subSample.numBytesOfClearData;
583         }
584 
585         if (subSample.numBytesOfEncryptedData > 0) {
586             AES_ctr128_encrypt(src + offset, dest + offset, subSample.numBytesOfEncryptedData,
587                                &decryptionKey, iv, previousEncryptedCounter, &blockOffset);
588             offset += subSample.numBytesOfEncryptedData;
589         }
590     }
591 }
592 
593 /**
594  * Decrypt a list of clear+encrypted subsamples using the specified key
595  * in AES-CBC mode
596  */
aes_cbc_decrypt(uint8_t * dest,uint8_t * src,uint8_t * iv,const vector<SubSample> & subSamples,const vector<uint8_t> & key)597 void DrmHalTest::aes_cbc_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
598                                  const vector<SubSample>& subSamples, const vector<uint8_t>& key) {
599     AES_KEY decryptionKey;
600     AES_set_encrypt_key(&key[0], 128, &decryptionKey);
601 
602     size_t offset = 0;
603     for (size_t i = 0; i < subSamples.size(); i++) {
604         memcpy(dest + offset, src + offset, subSamples[i].numBytesOfClearData);
605         offset += subSamples[i].numBytesOfClearData;
606 
607         AES_cbc_encrypt(src + offset, dest + offset, subSamples[i].numBytesOfEncryptedData,
608                         &decryptionKey, iv, 0 /* decrypt */);
609         offset += subSamples[i].numBytesOfEncryptedData;
610     }
611 }
612 
613 /**
614  * Helper method to test decryption with invalid keys is returned
615  */
decryptWithInvalidKeys(vector<uint8_t> & invalidResponse,vector<uint8_t> & iv,const Pattern & noPattern,const vector<SubSample> & subSamples)616 void DrmHalClearkeyTest::decryptWithInvalidKeys(vector<uint8_t>& invalidResponse,
617                                                 vector<uint8_t>& iv, const Pattern& noPattern,
618                                                 const vector<SubSample>& subSamples) {
619     DrmHalVTSVendorModule_V1::ContentConfiguration content = getContent();
620     if (content.keys.empty()) {
621         FAIL() << "no keys";
622     }
623 
624     const auto& key = content.keys[0];
625     auto sessionId = openSession();
626     KeySetId result;
627     auto ret = drmPlugin->provideKeyResponse(sessionId, invalidResponse, &result);
628 
629     EXPECT_OK(ret);
630     EXPECT_EQ(0u, result.keySetId.size());
631 
632     EXPECT_OK(cryptoPlugin->setMediaDrmSession(sessionId));
633 
634     uint32_t byteCount =
635             decrypt(Mode::AES_CTR, key.isSecure, toStdArray(key.keyId), &iv[0], subSamples,
636                     noPattern, key.clearContentKey, Status::ERROR_DRM_NO_LICENSE);
637     EXPECT_EQ(0u, byteCount);
638 
639     closeSession(sessionId);
640 }
641 
642 }  // namespace vts
643 }  // namespace drm
644 }  // namespace hardware
645 }  // namespace android
646 }  // namespace aidl
647