/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <HalInterfaces.h>
#include <SampleDriver.h>
#include <android-base/scopeguard.h>
#include <gtest/gtest.h>

#include <cstdlib>
#include <filesystem>
#include <numeric>
#include <string>
#include <string_view>
#include <tuple>
#include <vector>

#include "HalUtils.h"
#include "Manager.h"
#include "TestNeuralNetworksWrapper.h"
#include "TmpDirectoryUtils.h"

using namespace android::nn;
namespace hardware = android::hardware;
using WrapperResult = test_wrapper::Result;
using Type = test_wrapper::Type;
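// Timing values of UINT64_MAX indicate that no timing measurement is available.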
const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
template <typename T>
using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
using android::sp;

namespace android::hardware::neuralnetworks::V1_0 {

::std::ostream& operator<<(::std::ostream& os, V1_3::ErrorStatus errorStatus) {
    return os << toString(errorStatus);
}

}  // namespace android::hardware::neuralnetworks::V1_0

namespace {

enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING };

// Print HasCalledPrepareModel enum for better GTEST failure messages
std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
    switch (hasCalledPrepareModel) {
        case HasCalledPrepareModel::NO:
            return os << "NO";
        case HasCalledPrepareModel::WITHOUT_CACHING:
            return os << "WITHOUT_CACHING";
        case HasCalledPrepareModel::WITH_CACHING:
            return os << "WITH_CACHING";
    }
    CHECK(false) << "HasCalledPrepareModel print called with invalid code "
                 << static_cast<int>(hasCalledPrepareModel);
    return os;
}

// Whether the driver is expected to be registered because it can pass initialization.
bool canDeviceBeRegistered(V1_3::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
    constexpr uint32_t maxNumCacheFiles =
            static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES);
    return error == V1_3::ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
           numDataCache <= maxNumCacheFiles;
}

// Whether the driver supports caching, based on the values returned by getNumberOfCacheFilesNeeded.
bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
    return numModelCache != 0 || numDataCache != 0;
}

// This is an IDevice for testing purposes which overrides several methods from the sample driver:
// - supports all the operations and is faster than the CPU fallback.
// - overrides getNumberOfCacheFilesNeeded to report according to the given parameters.
// - overrides prepareModelFromCache_1_3 to return an error status according to
//   mErrorStatusPrepareFromCache.
// - produces CachingPreparedModel on prepareModel and prepareModelFromCache_1_3.
//
// The cache entry is written by prepareModel_1_3 and is checked later by
// CachingDriver::prepareModelFromCache_1_3.
//
// The CachingDriver has two flags, mHasCalledPrepareModelFromCache and mHasCalledPrepareModel,
// to check whether the correct methods are invoked by the runtime.
class CachingDriver : public sample_driver::SampleDriver {
   private:
    static constexpr size_t kCacheSize = 256;

    class CachingPreparedModel : public V1_3::IPreparedModel {
       public:
        CachingPreparedModel() = default;

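        // Every execution entry point reports DEVICE_UNAVAILABLE; these tests exercise
        // compilation and caching only, never execution.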
        hardware::Return<V1_0::ErrorStatus> execute(const V1_0::Request&,
                                                    const sp<V1_0::IExecutionCallback>&) override {
            return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        hardware::Return<V1_0::ErrorStatus> execute_1_2(
                const V1_0::Request&, V1_2::MeasureTiming,
                const sp<V1_2::IExecutionCallback>&) override {
            return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        hardware::Return<V1_3::ErrorStatus> execute_1_3(
                const V1_3::Request&, V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
                const V1_3::OptionalTimeoutDuration&,
                const sp<V1_3::IExecutionCallback>&) override {
            return V1_3::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        hardware::Return<void> executeSynchronously(const V1_0::Request&, V1_2::MeasureTiming,
                                                    executeSynchronously_cb cb) override {
            cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
            return hardware::Void();
        }
        hardware::Return<void> executeSynchronously_1_3(const V1_3::Request&, V1_2::MeasureTiming,
                                                        const V1_3::OptionalTimePoint&,
                                                        const V1_3::OptionalTimeoutDuration&,
                                                        executeSynchronously_1_3_cb cb) override {
            cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
            return hardware::Void();
        }
        hardware::Return<void> configureExecutionBurst(
                const sp<V1_2::IBurstCallback>&, const MQDescriptorSync<V1_2::FmqRequestDatum>&,
                const MQDescriptorSync<V1_2::FmqResultDatum>&,
                configureExecutionBurst_cb cb) override {
            cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
            return hardware::Void();
        }
        hardware::Return<void> executeFenced(const V1_3::Request&,
                                             const hardware::hidl_vec<hardware::hidl_handle>&,
                                             V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
                                             const V1_3::OptionalTimeoutDuration&,
                                             const V1_3::OptionalTimeoutDuration&,
                                             executeFenced_cb cb) {
            cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, hardware::hidl_handle(nullptr), nullptr);
            return hardware::Void();
        }
    };

   public:
    CachingDriver(std::string_view name, V1_3::ErrorStatus errorStatusGetNumCacheFiles,
                  uint32_t numModelCache, uint32_t numDataCache,
                  V1_3::ErrorStatus errorStatusPrepareFromCache)
        : SampleDriver(name.data()),
          mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
          mNumModelCache(numModelCache),
          mNumDataCache(numDataCache),
          mErrorStatusPrepareFromCache(errorStatusPrepareFromCache) {
        mModelCacheData.resize(kCacheSize);
        std::iota(mModelCacheData.begin(), mModelCacheData.end(), 0);
        mDataCacheData.resize(kCacheSize);
        std::iota(mDataCacheData.begin(), mDataCacheData.end(), 1);
    }
    ~CachingDriver() override {}

    // Reports performance faster than the CPU fallback.
    hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
        android::nn::initVLogMask();
        const V1_0::PerformanceInfo kPerf = {.execTime = 0.1, .powerUsage = 0.1};
        V1_3::Capabilities capabilities = {
                .relaxedFloat32toFloat16PerformanceScalar = kPerf,
                .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
                .ifPerformance = kPerf,
                .whilePerformance = kPerf};
        cb(V1_3::ErrorStatus::NONE, capabilities);
        return hardware::Void();
    }

    // Reports supporting all operations.
    hardware::Return<void> getSupportedOperations_1_3(const V1_3::Model& model,
                                                      getSupportedOperations_1_3_cb cb) override {
        std::vector<bool> supported(model.main.operations.size(), true);
        cb(V1_3::ErrorStatus::NONE, supported);
        return hardware::Void();
    }

    // Reports according to mErrorStatusGetNumCacheFiles, mNumModelCache, and mNumDataCache.
    hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override {
        cb(convertToV1_0(mErrorStatusGetNumCacheFiles), mNumModelCache, mNumDataCache);
        return hardware::Void();
    }

    // Generates CachingPreparedModel.
    // Writes the cache entries from mModelCacheData/mDataCacheData and sets mHasCalledPrepareModel.
    hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
            const V1_3::Model&, V1_1::ExecutionPreference, V1_3::Priority,
            const V1_3::OptionalTimePoint&,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
            const sp<V1_3::IPreparedModelCallback>& cb) override {
        checkNumberOfCacheHandles(modelCacheHandle.size(), dataCacheHandle.size());
        if (modelCacheHandle.size() != 0 || dataCacheHandle.size() != 0) {
            writeToCache(modelCacheHandle, mModelCacheData);
            writeToCache(dataCacheHandle, mDataCacheData);
            mHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
        } else {
            mHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
        }
        cb->notify_1_3(V1_3::ErrorStatus::NONE, new CachingPreparedModel());
        return V1_3::ErrorStatus::NONE;
    }

    // Checks that the cache entries are correct, notifies the error status according to
    // mErrorStatusPrepareFromCache, and sets mHasCalledPrepareModelFromCache.
    hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
            const V1_3::OptionalTimePoint&,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
            const sp<V1_3::IPreparedModelCallback>& callback) override {
        readFromCache(modelCacheHandle, mModelCacheData);
        readFromCache(dataCacheHandle, mDataCacheData);
        mHasCalledPrepareModelFromCache = true;
        if (mErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            callback->notify_1_3(mErrorStatusPrepareFromCache, new CachingPreparedModel());
        } else {
            callback->notify_1_3(mErrorStatusPrepareFromCache, nullptr);
        }
        return V1_3::ErrorStatus::NONE;
    };

    bool hasCalledPrepareModelFromCache() const { return mHasCalledPrepareModelFromCache; }
    HasCalledPrepareModel hasCalledPrepareModel() const { return mHasCalledPrepareModel; }

   private:
    // Checks the number of cache files passed to the driver from the runtime.
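    // If the driver reports caching support, the runtime passes either the full sets of handles
    // or none at all; if caching is unsupported, the runtime must pass none.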
    void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
        if (isCachingSupported(mNumModelCache, mNumDataCache)) {
            if (modelCache != 0 || dataCache != 0) {
                ASSERT_EQ(modelCache, mNumModelCache);
                ASSERT_EQ(dataCache, mNumDataCache);
            }
        } else {
            ASSERT_EQ(modelCache, 0ul);
            ASSERT_EQ(dataCache, 0ul);
        }
    }

    void writeToCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
                      const std::vector<uint8_t>& cache) {
        for (uint32_t i = 0; i < handles.size(); ++i) {
            ASSERT_EQ(handles[i]->numFds, 1);
            EXPECT_EQ(write(handles[i]->data[0], cache.data(), kCacheSize),
                      static_cast<ssize_t>(kCacheSize));
        }
    }

    void readFromCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
                       const std::vector<uint8_t>& expected) {
        for (uint32_t i = 0; i < handles.size(); ++i) {
            ASSERT_EQ(handles[i]->numFds, 1);
            std::vector<uint8_t> actual(kCacheSize);
            EXPECT_EQ(read(handles[i]->data[0], actual.data(), kCacheSize),
                      static_cast<ssize_t>(kCacheSize));
            EXPECT_EQ(actual, expected);
        }
    }

    std::vector<uint8_t> mModelCacheData;
    std::vector<uint8_t> mDataCacheData;

    const V1_3::ErrorStatus mErrorStatusGetNumCacheFiles;
    const uint32_t mNumModelCache;
    const uint32_t mNumDataCache;
    const V1_3::ErrorStatus mErrorStatusPrepareFromCache;

    bool mHasCalledPrepareModelFromCache = false;
    HasCalledPrepareModel mHasCalledPrepareModel = HasCalledPrepareModel::NO;
};

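// Builds a small model with a single broadcast ADD operation ({2, 2} matrix plus {2} vector) so
// that the caching tests have something to compile.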
void CreateBroadcastAddModel(test_wrapper::Model* model) {
    test_wrapper::OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
    test_wrapper::OperandType vectorType(Type::TENSOR_FLOAT32, {2});
    test_wrapper::OperandType scalarType(Type::INT32, {});
    int32_t activation(ANEURALNETWORKS_FUSED_NONE);
    auto a = model->addOperand(&matrixType);
    auto b = model->addOperand(&vectorType);
    auto c = model->addOperand(&matrixType);
    auto d = model->addOperand(&scalarType);
    model->setOperandValue(d, &activation, sizeof(activation));
    model->addOperation(ANEURALNETWORKS_ADD, {a, b, d}, {c});
    model->identifyInputsAndOutputs({a, b}, {c});
    ASSERT_TRUE(model->isValid());
    ASSERT_EQ(model->finish(), WrapperResult::NO_ERROR);
}

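// Finds the registered device named deviceName; *outputDevice is left unchanged if no such device
// exists. At most one matching device is expected.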
void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
    uint32_t numDevices = 0;
    ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
    EXPECT_GE(numDevices, (uint32_t)1);

    int numMatchingDevices = 0;
    for (uint32_t i = 0; i < numDevices; i++) {
        ANeuralNetworksDevice* device = nullptr;
        ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);

        const char* buffer = nullptr;
        ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
        if (deviceName == buffer) {
            *outputDevice = device;
            numMatchingDevices++;
        }
    }

    EXPECT_LE(numMatchingDevices, 1);
}

// Test device registration with a driver parameterized with
// - ErrorStatus returned from getNumberOfCacheFilesNeeded
// - Number of model cache files returned from getNumberOfCacheFilesNeeded
// - Number of data cache files returned from getNumberOfCacheFilesNeeded
using DeviceRegistrationTestParam = std::tuple<V1_3::ErrorStatus, uint32_t, uint32_t>;

class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
   protected:
    static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
    const V1_3::ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
    const uint32_t kNumModelCache = std::get<1>(GetParam());
    const uint32_t kNumDataCache = std::get<2>(GetParam());
    const sp<CachingDriver> kDriver =
            new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
                              kNumDataCache, V1_3::ErrorStatus::NONE);
};

TEST_P(DeviceRegistrationTest, CachingFailure) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }

    DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), kDriver));
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

    // Get the device registered above, if any.
    const ANeuralNetworksDevice* device = nullptr;
    getDeviceWithName(kDeviceName, &device);

    // Check whether device registration matches expectations.
    const bool isDeviceRegistered = (device != nullptr);
    const bool expectDeviceToBeRegistered =
            canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);
    ASSERT_EQ(isDeviceRegistered, expectDeviceToBeRegistered);
}

// Test model compilation with a driver parameterized with
// - Number of model cache files returned from getNumberOfCacheFilesNeeded
// - Number of data cache files returned from getNumberOfCacheFilesNeeded
// - ErrorStatus returned from prepareModelFromCache_1_3
using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, V1_3::ErrorStatus>;

class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
   protected:
    virtual void SetUp() override {
        char cacheDirTemp[] = NN_TMP_DIR "/AVeryLongDirectoryNameForTestCompilationCachingXXXXXX";
        char* cacheDir = mkdtemp(cacheDirTemp);
        ASSERT_NE(cacheDir, nullptr);
        mCacheDir = cacheDir;
        CreateBroadcastAddModel(&mModel);
    }

    virtual void TearDown() override {
        if (!::testing::Test::HasFailure()) {
            std::filesystem::remove_all(mCacheDir);
        }
    }

    void compileModel(const sp<CachingDriver>& driver, bool withToken) {
        DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), driver));
        const auto cleanup = android::base::make_scope_guard(
                [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

        // Get a handle to the single driver device matching kDeviceName.
        const ANeuralNetworksDevice* device = nullptr;
        getDeviceWithName(kDeviceName, &device);
        ASSERT_NE(device, nullptr);

        // Compile the model with the device.
        ANeuralNetworksCompilation* compilation = nullptr;
        ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
                                                              &compilation),
                  ANEURALNETWORKS_NO_ERROR);
        if (withToken) {
            ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
                                                            kToken.data()),
                      ANEURALNETWORKS_NO_ERROR);
        }
        ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);

        // Free the compilation object.
        ANeuralNetworksCompilation_free(compilation);
    }

    void createCache() {
        sp<CachingDriver> driver =
                new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache,
                                  kNumDataCache, V1_3::ErrorStatus::NONE);
        compileModel(driver, /*withToken=*/true);
    }

    static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
    const uint32_t kNumModelCache = std::get<0>(GetParam());
    const uint32_t kNumDataCache = std::get<1>(GetParam());
    const V1_3::ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
    const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
    test_wrapper::Model mModel;
    std::string mCacheDir;
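    // A zero-filled token is sufficient here: CachingDriver never inspects the token contents.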
    const HalCacheToken kToken{};
};

TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // When the cache files do not exist, the runtime should never call prepareModelFromCache_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());

    // The runtime should call prepareModel_1_3, requesting caching iff caching is supported.
    EXPECT_EQ(driver->hasCalledPrepareModel(), kIsCachingSupported
                                                       ? HasCalledPrepareModel::WITH_CACHING
                                                       : HasCalledPrepareModel::WITHOUT_CACHING);
}

TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    createCache();
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // When the cache files exist, the runtime should call prepareModelFromCache_1_3 iff caching
    // is supported.
    EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported);

    HasCalledPrepareModel expectHasCalledPrepareModel;
    if (kIsCachingSupported) {
        if (kErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            // Caching is supported and prepareModelFromCache_1_3 succeeds: the runtime should not
            // call prepareModel_1_3 at all.
            expectHasCalledPrepareModel = HasCalledPrepareModel::NO;
        } else {
            // Caching is supported but prepareModelFromCache_1_3 fails: the runtime should fall
            // back to prepareModel_1_3 and still request caching.
            expectHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
        }
    } else {
        // Caching is not supported: the runtime should call prepareModel_1_3 without caching.
        expectHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
    }
    EXPECT_EQ(driver->hasCalledPrepareModel(), expectHasCalledPrepareModel);
}

TEST_P(CompilationCachingTest, TokenNotProvided) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/false);

    // When no NDK token is provided by the client, the runtime should never call
    // prepareModelFromCache_1_3 or request caching with prepareModel_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
    EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}

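// kNumCacheChoices includes MAX_NUMBER_OF_CACHE_FILES + 1, an invalid count used only by
// DeviceRegistrationTest to check that such a driver is rejected; CompilationCachingTest is
// instantiated with kNumValidCacheChoices only.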
static const auto kErrorStatusGetNumCacheFilesChoices =
        testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::DEVICE_UNAVAILABLE);
static const auto kNumCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES),
                        static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
static const auto kNumValidCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES));
static const auto kErrorStatusPrepareFromCacheChoices =
        testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE,
                        V1_3::ErrorStatus::DEVICE_UNAVAILABLE, V1_3::ErrorStatus::INVALID_ARGUMENT);

INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, DeviceRegistrationTest,
                         testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
                                          kNumCacheChoices));

INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
                         testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
                                          kErrorStatusPrepareFromCacheChoices));

}  // namespace