1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <HalInterfaces.h>
18 #include <SampleDriver.h>
19 #include <android-base/scopeguard.h>
20 #include <gtest/gtest.h>
21
22 #include <cstdlib>
23 #include <filesystem>
24 #include <numeric>
25 #include <string>
26 #include <string_view>
27 #include <tuple>
28 #include <vector>
29
30 #include "HalUtils.h"
31 #include "Manager.h"
32 #include "TestNeuralNetworksWrapper.h"
33 #include "TmpDirectoryUtils.h"
34
35 using namespace android::nn;
36 namespace hardware = android::hardware;
37 using WrapperResult = test_wrapper::Result;
38 using Type = test_wrapper::Type;
// Timing value passed, together with DEVICE_UNAVAILABLE, to the callbacks of the stub
// synchronous-execution methods of CachingPreparedModel. Both fields are UINT64_MAX —
// presumably the HAL's "no timing information available" value; confirm against the
// V1_2::Timing documentation.
const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
// Shorthand for ::android::hardware::MQDescriptorSync, used in the configureExecutionBurst
// signature of CachingPreparedModel.
template <typename T>
using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
42 using android::sp;
43
44 namespace android::hardware::neuralnetworks::V1_0 {
45
operator <<(::std::ostream & os,V1_3::ErrorStatus errorStatus)46 ::std::ostream& operator<<(::std::ostream& os, V1_3::ErrorStatus errorStatus) {
47 return os << toString(errorStatus);
48 }
49
50 } // namespace android::hardware::neuralnetworks::V1_0
51
52 namespace {
53
// Records which flavor of prepareModel_1_3, if any, the runtime invoked on a CachingDriver.
enum class HasCalledPrepareModel {
    NO,               // prepareModel_1_3 has not been called.
    WITHOUT_CACHING,  // prepareModel_1_3 was called with no cache handles.
    WITH_CACHING      // prepareModel_1_3 was called with model and/or data cache handles.
};
55
56 // Print HasCalledPrepareModel enum for better GTEST failure messages
operator <<(std::ostream & os,HasCalledPrepareModel hasCalledPrepareModel)57 std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
58 switch (hasCalledPrepareModel) {
59 case HasCalledPrepareModel::NO:
60 return os << "NO";
61 case HasCalledPrepareModel::WITHOUT_CACHING:
62 return os << "WITHOUT_CACHING";
63 case HasCalledPrepareModel::WITH_CACHING:
64 return os << "WITH_CACHING";
65 }
66 CHECK(false) << "HasCalledPrepareModel print called with invalid code "
67 << static_cast<int>(hasCalledPrepareModel);
68 return os;
69 }
70
71 // Whether the driver is expected to be registered because it can pass initialization.
canDeviceBeRegistered(V1_3::ErrorStatus error,uint32_t numModelCache,uint32_t numDataCache)72 bool canDeviceBeRegistered(V1_3::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
73 constexpr uint32_t maxNumCacheFiles =
74 static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES);
75 return error == V1_3::ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
76 numDataCache <= maxNumCacheFiles;
77 }
78
// Whether a driver supports compilation caching, judged from the cache file counts its
// getNumberOfCacheFilesNeeded returned: any nonzero count means caching is supported.
bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
    const bool needsNoCacheFiles = (numModelCache == 0) && (numDataCache == 0);
    return !needsNoCacheFiles;
}
83
84 // This is an IDevice for testing purposes which overrides several methods from sample driver:
85 // - supports all the operations and is faster than cpu fallback.
86 // - overrides getNumberOfCacheFilesNeeded to report according to given parameters.
87 // - overrides prepareModelFromCache_1_3 to return error status according to
88 // mErrorStatusPrepareFromCache.
89 // - produces CachingPreparedModel on prepareModel and prepareModelFromCache_1_3.
90 //
91 // The cache entry is written by prepareModel_1_3 and is checked later by
92 // CachingDriver::prepareModelFromCache_1_3.
93 //
94 // The CachingDriver has 2 flags mHasCalledPrepareModelFromCache and mHasCalledPrepareModel
95 // to check if the correct methods are invoked by the runtime.
96 class CachingDriver : public sample_driver::SampleDriver {
97 private:
98 static constexpr size_t kCacheSize = 256;
99
100 class CachingPreparedModel : public V1_3::IPreparedModel {
101 public:
102 CachingPreparedModel() = default;
103
execute(const V1_0::Request &,const sp<V1_0::IExecutionCallback> &)104 hardware::Return<V1_0::ErrorStatus> execute(const V1_0::Request&,
105 const sp<V1_0::IExecutionCallback>&) override {
106 return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
107 }
execute_1_2(const V1_0::Request &,V1_2::MeasureTiming,const sp<V1_2::IExecutionCallback> &)108 hardware::Return<V1_0::ErrorStatus> execute_1_2(
109 const V1_0::Request&, V1_2::MeasureTiming,
110 const sp<V1_2::IExecutionCallback>&) override {
111 return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
112 }
execute_1_3(const V1_3::Request &,V1_2::MeasureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,const sp<V1_3::IExecutionCallback> &)113 hardware::Return<V1_3::ErrorStatus> execute_1_3(
114 const V1_3::Request&, V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
115 const V1_3::OptionalTimeoutDuration&,
116 const sp<V1_3::IExecutionCallback>&) override {
117 return V1_3::ErrorStatus::DEVICE_UNAVAILABLE;
118 }
executeSynchronously(const V1_0::Request &,V1_2::MeasureTiming,executeSynchronously_cb cb)119 hardware::Return<void> executeSynchronously(const V1_0::Request&, V1_2::MeasureTiming,
120 executeSynchronously_cb cb) override {
121 cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
122 return hardware::Void();
123 }
executeSynchronously_1_3(const V1_3::Request &,V1_2::MeasureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,executeSynchronously_1_3_cb cb)124 hardware::Return<void> executeSynchronously_1_3(const V1_3::Request&, V1_2::MeasureTiming,
125 const V1_3::OptionalTimePoint&,
126 const V1_3::OptionalTimeoutDuration&,
127 executeSynchronously_1_3_cb cb) override {
128 cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
129 return hardware::Void();
130 }
configureExecutionBurst(const sp<V1_2::IBurstCallback> &,const MQDescriptorSync<V1_2::FmqRequestDatum> &,const MQDescriptorSync<V1_2::FmqResultDatum> &,configureExecutionBurst_cb cb)131 hardware::Return<void> configureExecutionBurst(
132 const sp<V1_2::IBurstCallback>&, const MQDescriptorSync<V1_2::FmqRequestDatum>&,
133 const MQDescriptorSync<V1_2::FmqResultDatum>&,
134 configureExecutionBurst_cb cb) override {
135 cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
136 return hardware::Void();
137 }
executeFenced(const V1_3::Request &,const hardware::hidl_vec<hardware::hidl_handle> &,V1_2::MeasureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,const V1_3::OptionalTimeoutDuration &,executeFenced_cb cb)138 hardware::Return<void> executeFenced(const V1_3::Request&,
139 const hardware::hidl_vec<hardware::hidl_handle>&,
140 V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
141 const V1_3::OptionalTimeoutDuration&,
142 const V1_3::OptionalTimeoutDuration&,
143 executeFenced_cb cb) {
144 cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, hardware::hidl_handle(nullptr), nullptr);
145 return hardware::Void();
146 }
147 };
148
149 public:
CachingDriver(std::string_view name,V1_3::ErrorStatus errorStatusGetNumCacheFiles,uint32_t numModelCache,uint32_t numDataCache,V1_3::ErrorStatus errorStatusPrepareFromCache)150 CachingDriver(std::string_view name, V1_3::ErrorStatus errorStatusGetNumCacheFiles,
151 uint32_t numModelCache, uint32_t numDataCache,
152 V1_3::ErrorStatus errorStatusPrepareFromCache)
153 : SampleDriver(name.data()),
154 mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
155 mNumModelCache(numModelCache),
156 mNumDataCache(numDataCache),
157 mErrorStatusPrepareFromCache(errorStatusPrepareFromCache) {
158 mModelCacheData.resize(kCacheSize);
159 std::iota(mModelCacheData.begin(), mModelCacheData.end(), 0);
160 mDataCacheData.resize(kCacheSize);
161 std::iota(mDataCacheData.begin(), mDataCacheData.end(), 1);
162 }
~CachingDriver()163 ~CachingDriver() override {}
164
165 // Reports faster than cpu.
getCapabilities_1_3(getCapabilities_1_3_cb cb)166 hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
167 android::nn::initVLogMask();
168 const V1_0::PerformanceInfo kPerf = {.execTime = 0.1, .powerUsage = 0.1};
169 V1_3::Capabilities capabilities = {
170 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
171 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
172 .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
173 .ifPerformance = kPerf,
174 .whilePerformance = kPerf};
175 cb(V1_3::ErrorStatus::NONE, capabilities);
176 return hardware::Void();
177 }
178
179 // Reports supporting all operations.
getSupportedOperations_1_3(const V1_3::Model & model,getSupportedOperations_1_3_cb cb)180 hardware::Return<void> getSupportedOperations_1_3(const V1_3::Model& model,
181 getSupportedOperations_1_3_cb cb) override {
182 std::vector<bool> supported(model.main.operations.size(), true);
183 cb(V1_3::ErrorStatus::NONE, supported);
184 return hardware::Void();
185 }
186
187 // Reports according to mGetNumCacheFiles.
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)188 hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override {
189 cb(convertToV1_0(mErrorStatusGetNumCacheFiles), mNumModelCache, mNumDataCache);
190 return hardware::Void();
191 }
192
193 // Generates CachingPreparedModel.
194 // Writes the cache entry per mCacheXData and sets mHasCalledPrepareModel.
prepareModel_1_3(const V1_3::Model &,V1_1::ExecutionPreference,V1_3::Priority,const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> & modelCacheHandle,const hardware::hidl_vec<hardware::hidl_handle> & dataCacheHandle,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & cb)195 hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
196 const V1_3::Model&, V1_1::ExecutionPreference, V1_3::Priority,
197 const V1_3::OptionalTimePoint&,
198 const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
199 const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
200 const sp<V1_3::IPreparedModelCallback>& cb) override {
201 checkNumberOfCacheHandles(modelCacheHandle.size(), dataCacheHandle.size());
202 if (modelCacheHandle.size() != 0 || dataCacheHandle.size() != 0) {
203 writeToCache(modelCacheHandle, mModelCacheData);
204 writeToCache(dataCacheHandle, mDataCacheData);
205 mHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
206 } else {
207 mHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
208 }
209 cb->notify_1_3(V1_3::ErrorStatus::NONE, new CachingPreparedModel());
210 return V1_3::ErrorStatus::NONE;
211 }
212
213 // Checks if the cache entry is correct, notifies error status according to
214 // mErrorStatusPrepareFromCache, sets mHasCalledPrepareModelFromCache.
prepareModelFromCache_1_3(const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> & modelCacheHandle,const hardware::hidl_vec<hardware::hidl_handle> & dataCacheHandle,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)215 hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
216 const V1_3::OptionalTimePoint&,
217 const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
218 const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
219 const sp<V1_3::IPreparedModelCallback>& callback) override {
220 readFromCache(modelCacheHandle, mModelCacheData);
221 readFromCache(dataCacheHandle, mDataCacheData);
222 mHasCalledPrepareModelFromCache = true;
223 if (mErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
224 callback->notify_1_3(mErrorStatusPrepareFromCache, new CachingPreparedModel());
225 } else {
226 callback->notify_1_3(mErrorStatusPrepareFromCache, nullptr);
227 }
228 return V1_3::ErrorStatus::NONE;
229 };
230
hasCalledPrepareModelFromCache() const231 bool hasCalledPrepareModelFromCache() const { return mHasCalledPrepareModelFromCache; }
hasCalledPrepareModel() const232 HasCalledPrepareModel hasCalledPrepareModel() const { return mHasCalledPrepareModel; }
233
234 private:
235 // Checks the number of cache files passed to the driver from runtime.
checkNumberOfCacheHandles(size_t modelCache,size_t dataCache)236 void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
237 if (isCachingSupported(mNumModelCache, mNumDataCache)) {
238 if (modelCache != 0 || dataCache != 0) {
239 ASSERT_EQ(modelCache, mNumModelCache);
240 ASSERT_EQ(dataCache, mNumDataCache);
241 }
242 } else {
243 ASSERT_EQ(modelCache, 0ul);
244 ASSERT_EQ(dataCache, 0ul);
245 }
246 }
247
writeToCache(const hardware::hidl_vec<hardware::hidl_handle> & handles,const std::vector<uint8_t> & cache)248 void writeToCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
249 const std::vector<uint8_t>& cache) {
250 for (uint32_t i = 0; i < handles.size(); ++i) {
251 ASSERT_EQ(handles[i]->numFds, 1);
252 EXPECT_EQ(write(handles[i]->data[0], cache.data(), kCacheSize),
253 static_cast<ssize_t>(kCacheSize));
254 }
255 }
256
readFromCache(const hardware::hidl_vec<hardware::hidl_handle> & handles,const std::vector<uint8_t> & expected)257 void readFromCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
258 const std::vector<uint8_t>& expected) {
259 for (uint32_t i = 0; i < handles.size(); ++i) {
260 ASSERT_EQ(handles[i]->numFds, 1);
261 std::vector<uint8_t> actual(kCacheSize);
262 EXPECT_EQ(read(handles[i]->data[0], actual.data(), kCacheSize),
263 static_cast<ssize_t>(kCacheSize));
264 EXPECT_EQ(actual, expected);
265 }
266 }
267
268 std::vector<uint8_t> mModelCacheData;
269 std::vector<uint8_t> mDataCacheData;
270
271 const V1_3::ErrorStatus mErrorStatusGetNumCacheFiles;
272 const uint32_t mNumModelCache;
273 const uint32_t mNumDataCache;
274 const V1_3::ErrorStatus mErrorStatusPrepareFromCache;
275
276 bool mHasCalledPrepareModelFromCache = false;
277 HasCalledPrepareModel mHasCalledPrepareModel = HasCalledPrepareModel::NO;
278 };
279
CreateBroadcastAddModel(test_wrapper::Model * model)280 void CreateBroadcastAddModel(test_wrapper::Model* model) {
281 test_wrapper::OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
282 test_wrapper::OperandType vectorType(Type::TENSOR_FLOAT32, {2});
283 test_wrapper::OperandType scalarType(Type::INT32, {});
284 int32_t activation(ANEURALNETWORKS_FUSED_NONE);
285 auto a = model->addOperand(&matrixType);
286 auto b = model->addOperand(&vectorType);
287 auto c = model->addOperand(&matrixType);
288 auto d = model->addOperand(&scalarType);
289 model->setOperandValue(d, &activation, sizeof(activation));
290 model->addOperation(ANEURALNETWORKS_ADD, {a, b, d}, {c});
291 model->identifyInputsAndOutputs({a, b}, {c});
292 ASSERT_TRUE(model->isValid());
293 ASSERT_EQ(model->finish(), WrapperResult::NO_ERROR);
294 }
295
getDeviceWithName(std::string_view deviceName,const ANeuralNetworksDevice ** outputDevice)296 void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
297 uint32_t numDevices = 0;
298 ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
299 EXPECT_GE(numDevices, (uint32_t)1);
300
301 int numMatchingDevices = 0;
302 for (uint32_t i = 0; i < numDevices; i++) {
303 ANeuralNetworksDevice* device = nullptr;
304 ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
305
306 const char* buffer = nullptr;
307 ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
308 if (deviceName == buffer) {
309 *outputDevice = device;
310 numMatchingDevices++;
311 }
312 }
313
314 EXPECT_LE(numMatchingDevices, 1);
315 }
316
317 // Test device registration with a driver parameterized with
318 // - ErrorStatus returning from getNumberOfCacheFilesNeeded
319 // - Number of model cache files returning from getNumberOfCacheFilesNeeded
320 // - Number of data cache files returning from getNumberOfCacheFilesNeeded
321 using DeviceRegistrationTestParam = std::tuple<V1_3::ErrorStatus, uint32_t, uint32_t>;
322
323 class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
324 protected:
325 static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
326 const V1_3::ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
327 const uint32_t kNumModelCache = std::get<1>(GetParam());
328 const uint32_t kNumDataCache = std::get<2>(GetParam());
329 const sp<CachingDriver> kDriver =
330 new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
331 kNumDataCache, V1_3::ErrorStatus::NONE);
332 };
333
// Registers a CachingDriver built from the test parameters and checks that the runtime exposes
// it exactly when canDeviceBeRegistered() predicts it passes initialization.
TEST_P(DeviceRegistrationTest, CachingFailure) {
    // Returns early (no checks) when the runtime reports CPU-only mode.
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }

    DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), kDriver));
    // Reinitialize the device list on every exit path of this test.
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

    // get device. getDeviceWithName reports problems with ASSERT_*, which only return from the
    // helper; stop here rather than judge registration from an incomplete lookup.
    const ANeuralNetworksDevice* device = nullptr;
    ASSERT_NO_FATAL_FAILURE(getDeviceWithName(kDeviceName, &device));

    // check if device registration matches expectations
    const bool isDeviceRegistered = (device != nullptr);
    const bool expectDeviceToBeRegistered =
            canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);
    ASSERT_EQ(isDeviceRegistered, expectDeviceToBeRegistered);
}
353
354 // Test model compilation with a driver parameterized with
355 // - Number of model cache files returning from getNumberOfCacheFilesNeeded
356 // - Number of data cache files returning from getNumberOfCacheFilesNeeded
357 // - ErrorStatus returning from prepareModelFromCache_1_3
358 using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, V1_3::ErrorStatus>;
359
360 class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
361 protected:
SetUp()362 virtual void SetUp() override {
363 char cacheDirTemp[] = NN_TMP_DIR "/AVeryLongDirectoryNameForTestCompilationCachingXXXXXX";
364 char* cacheDir = mkdtemp(cacheDirTemp);
365 ASSERT_NE(cacheDir, nullptr);
366 mCacheDir = cacheDir;
367 CreateBroadcastAddModel(&mModel);
368 }
369
TearDown()370 virtual void TearDown() override {
371 if (!::testing::Test::HasFailure()) {
372 std::filesystem::remove_all(mCacheDir);
373 }
374 }
375
compileModel(const sp<CachingDriver> & driver,bool withToken)376 void compileModel(const sp<CachingDriver>& driver, bool withToken) {
377 DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), driver));
378 const auto cleanup = android::base::make_scope_guard(
379 [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
380
381 // Get a handle to the single driver device matching kDeviceName.
382 const ANeuralNetworksDevice* device = nullptr;
383 getDeviceWithName(kDeviceName, &device);
384 ASSERT_NE(device, nullptr);
385
386 // Compile the model with the device.
387 ANeuralNetworksCompilation* compilation = nullptr;
388 ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
389 &compilation),
390 ANEURALNETWORKS_NO_ERROR);
391 if (withToken) {
392 ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
393 kToken.data()),
394 ANEURALNETWORKS_NO_ERROR);
395 }
396 ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);
397
398 // close memory
399 ANeuralNetworksCompilation_free(compilation);
400 }
401
createCache()402 void createCache() {
403 sp<CachingDriver> driver =
404 new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache,
405 kNumDataCache, V1_3::ErrorStatus::NONE);
406 compileModel(driver, /*withToken=*/true);
407 }
408
409 static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
410 const uint32_t kNumModelCache = std::get<0>(GetParam());
411 const uint32_t kNumDataCache = std::get<1>(GetParam());
412 const V1_3::ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
413 const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
414 test_wrapper::Model mModel;
415 std::string mCacheDir;
416 const HalCacheToken kToken{};
417 };
418
// Compiling with a token but no existing cache must go through prepareModel_1_3 only.
TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    // compileModel reports problems via ASSERT_*, which only return from the helper; do not
    // inspect the driver if compilation did not complete.
    ASSERT_NO_FATAL_FAILURE(compileModel(driver, /*withToken=*/true));

    // When cache file does not exist, the runtime should never call prepareModelFromCache_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());

    // The runtime should call prepareModel_1_3. It should request caching iff caching supported.
    EXPECT_EQ(driver->hasCalledPrepareModel(), kIsCachingSupported
                                                       ? HasCalledPrepareModel::WITH_CACHING
                                                       : HasCalledPrepareModel::WITHOUT_CACHING);
}
436
// Compiling with a token after the cache has been populated should use
// prepareModelFromCache_1_3 when caching is supported, falling back to prepareModel_1_3 as
// described below.
TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    // Both helpers report problems via ASSERT_*, which only return from the helper; stop the
    // test instead of checking expectations against a cache or compilation that never happened.
    ASSERT_NO_FATAL_FAILURE(createCache());
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    ASSERT_NO_FATAL_FAILURE(compileModel(driver, /*withToken=*/true));

    // When cache files exist, the runtime should call prepareModelFromCache_1_3 iff caching
    // supported.
    EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported);

    const HasCalledPrepareModel expectHasCalledPrepareModel = [this] {
        if (!kIsCachingSupported) {
            // The runtime should call prepareModel_1_3 without caching iff caching not
            // supported.
            return HasCalledPrepareModel::WITHOUT_CACHING;
        }
        if (kErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            // The runtime should not call prepareModel_1_3 iff caching supported and
            // prepareModelFromCache_1_3 succeeds.
            return HasCalledPrepareModel::NO;
        }
        // The runtime should call prepareModel_1_3 and request caching iff caching supported
        // but prepareModelFromCache_1_3 fails.
        return HasCalledPrepareModel::WITH_CACHING;
    }();
    EXPECT_EQ(driver->hasCalledPrepareModel(), expectHasCalledPrepareModel);
}
468
// Compiling without a token must never touch the cache.
TEST_P(CompilationCachingTest, TokenNotProvided) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    // compileModel reports problems via ASSERT_*, which only return from the helper; do not
    // inspect the driver if compilation did not complete.
    ASSERT_NO_FATAL_FAILURE(compileModel(driver, /*withToken=*/false));

    // When no NDK token is provided by the client, the runtime should never call
    // prepareModelFromCache_1_3 or request caching with prepareModel_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
    EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}
483
484 static const auto kErrorStatusGetNumCacheFilesChoices =
485 testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::DEVICE_UNAVAILABLE);
486 static const auto kNumCacheChoices =
487 testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES),
488 static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
489 static const auto kNumValidCacheChoices =
490 testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES));
491 static const auto kErrorStatusPrepareFromCacheChoices =
492 testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE,
493 V1_3::ErrorStatus::DEVICE_UNAVAILABLE, V1_3::ErrorStatus::INVALID_ARGUMENT);
494
495 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, DeviceRegistrationTest,
496 testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
497 kNumCacheChoices));
498
499 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
500 testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
501 kErrorStatusPrepareFromCacheChoices));
502
503 } // namespace
504